#include "PPU.h"
#include "Log.h"

PNE_NAMESPACE_BEGIN

PPU::PPU(PictureBus &bus, VirtualScreen &screen) : m_bus(bus),
                                                   m_screen(screen),
                                                   m_spriteMemory(64 * 4),
                                                   m_pictureBuffer(ScanlineVisibleDots, std::vector<sf::Color>(VisibleScanlines, sf::Color::Magenta))
{
}

void PPU::Reset()
{
    m_longSprites = m_generateInterrupt = m_greyscaleMode = m_vblank = m_spriteOverflow = false;
    m_showBackground = m_showSprites = m_evenFrame = m_firstWrite = true;
    m_bgPage = m_sprPage = CharacterPage::Low;
    m_dataAddress = m_cycle = m_scanline = m_spriteDataAddress = m_fineXScroll = m_tempAddress = 0;

    m_dataAddrIncrement = 1;
    m_pipelineState = State::PreRender;
    m_scanlineSprites.reserve(8);
    m_scanlineSprites.resize(0);
}

void PPU::SetInterruptCallback(std::function<void(void)> callback)
{
    m_vblanCallback = callback;
}

void PPU::Step()
{
    switch (m_pipelineState)
    {
    case State::PreRender:
        if (m_cycle == 1)
            m_vblank = m_sprZeroHit = false;
        else if (m_cycle == ScanlineVisibleDots + 2 && m_showBackground && m_showSprites)
        {
            // Set bits related to horizontal position
            m_dataAddress &= ~0x41f;                // Unset horizontal bits
            m_dataAddress |= m_tempAddress & 0x41f; // Copy
        }
        else if (m_cycle > 280 && m_cycle <= 304 && m_showBackground && m_showSprites)
        {
            // Set vertical bits
            m_dataAddress &= ~0x7be0;                // Unset bits related to horizontal
            m_dataAddress |= m_tempAddress & 0x7be0; // Copy
        }
        //                 if (m_cycle > 257 && m_cycle < 320)
        //                     m_spriteDataAddress = 0;
        // if rendering is on, every other frame is one cycle shorter
        if (m_cycle >= ScanlineEndCycle - (!m_evenFrame && m_showBackground && m_showSprites))
        {
            m_pipelineState = State::Render;
            m_cycle = m_scanline = 0;
        }

        // add IRQ support for MMC3
        if (m_cycle == 260 && m_showBackground && m_showSprites)
        {
            m_bus.ScanlineIRQ();
        }
        break;
    case State::Render:
        if (m_cycle > 0 && m_cycle <= ScanlineVisibleDots)
        {
            Byte bgColor = 0, sprColor = 0;
            bool bgOpaque = false, sprOpaque = true;
            bool spriteForeground = false;

            int x = m_cycle - 1;
            int y = m_scanline;

            if (m_showBackground)
            {
                auto x_fine = (m_fineXScroll + x) % 8;
                if (!m_hideEdgeBackground || x >= 8)
                {
                    // fetch tile
                    auto addr = 0x2000 | (m_dataAddress & 0x0FFF); // mask off fine y
                    // auto addr = 0x2000 + x / 8 + (y / 8) * (ScanlineVisibleDots / 8);
                    Byte tile = read(addr);

                    // fetch pattern
                    // Each pattern occupies 16 bytes, so multiply by 16
                    addr = (tile * 16) + ((m_dataAddress >> 12 /*y % 8*/) & 0x7); // Add fine y
                    addr |= static_cast<int>(m_bgPage) << 12;                     // set whether the pattern is in the high or low page
                    // Get the corresponding bit determined by (8 - x_fine) from the right
                    bgColor = (read(addr) >> (7 ^ x_fine)) & 1;             // bit 0 of palette entry
                    bgColor |= ((read(addr + 8) >> (7 ^ x_fine)) & 1) << 1; // bit 1

                    bgOpaque = bgColor; // flag used to calculate final pixel with the sprite pixel

                    // fetch attribute and calculate higher two bits of palette
                    addr = 0x23C0 | (m_dataAddress & 0x0C00) | ((m_dataAddress >> 4) & 0x38) | ((m_dataAddress >> 2) & 0x07);
                    auto attribute = read(addr);
                    int shift = ((m_dataAddress >> 4) & 4) | (m_dataAddress & 2);
                    // Extract and set the upper two bits for the color
                    bgColor |= ((attribute >> shift) & 0x3) << 2;
                }
                // Increment/wrap coarse X
                if (x_fine == 7)
                {
                    if ((m_dataAddress & 0x001F) == 31) // if coarse X == 31
                    {
                        m_dataAddress &= ~0x001F; // coarse X = 0
                        m_dataAddress ^= 0x0400;  // switch horizontal nametable
                    }
                    else
                    {
                        m_dataAddress += 1; // increment coarse X
                    }
                }
            }

            if (m_showSprites && (!m_hideEdgeSprites || x >= 8))
            {
                for (auto i : m_scanlineSprites)
                {
                    Byte spr_x = m_spriteMemory[i * 4 + 3];

                    if (0 > x - spr_x || x - spr_x >= 8)
                        continue;

                    Byte spr_y = m_spriteMemory[i * 4 + 0] + 1,
                         tile = m_spriteMemory[i * 4 + 1],
                         attribute = m_spriteMemory[i * 4 + 2];

                    int length = (m_longSprites) ? 16 : 8;

                    int x_shift = (x - spr_x) % 8, y_offset = (y - spr_y) % length;

                    if ((attribute & 0x40) == 0) // If NOT flipping horizontally
                        x_shift ^= 7;
                    if ((attribute & 0x80) != 0) // IF flipping vertically
                        y_offset ^= (length - 1);

                    Address addr = 0;

                    if (!m_longSprites)
                    {
                        addr = tile * 16 + y_offset;
                        if (m_sprPage == CharacterPage::High)
                            addr += 0x1000;
                    }
                    else // 8x16 sprites
                    {
                        // bit-3 is one if it is the bottom tile of the sprite, multiply by two to get the next pattern
                        y_offset = (y_offset & 7) | ((y_offset & 8) << 1);
                        addr = (tile >> 1) * 32 + y_offset;
                        addr |= (tile & 1) << 12; // Bank 0x1000 if bit-0 is high
                    }

                    sprColor |= (read(addr) >> (x_shift)) & 1;            // bit 0 of palette entry
                    sprColor |= ((read(addr + 8) >> (x_shift)) & 1) << 1; // bit 1

                    if (!(sprOpaque = sprColor))
                    {
                        sprColor = 0;
                        continue;
                    }

                    sprColor |= 0x10;                   // Select sprite palette
                    sprColor |= (attribute & 0x3) << 2; // bits 2-3

                    spriteForeground = !(attribute & 0x20);

                    // Sprite-0 hit detection
                    if (!m_sprZeroHit && m_showBackground && i == 0 && sprOpaque && bgOpaque)
                    {
                        m_sprZeroHit = true;
                    }

                    break; // Exit the loop now since we've found the highest priority sprite
                }
            }

            Byte paletteAddr = bgColor;

            if ((!bgOpaque && sprOpaque) ||
                (bgOpaque && sprOpaque && spriteForeground))
                paletteAddr = sprColor;
            else if (!bgOpaque && !sprOpaque)
                paletteAddr = 0;
            // else bgColor

            m_pictureBuffer[x][y] = sf::Color(colors[m_bus.ReadPalette(paletteAddr)]);
        }
        else if (m_cycle == ScanlineVisibleDots + 1 && m_showBackground)
        {
            // Shamelessly copied from nesdev wiki
            if ((m_dataAddress & 0x7000) != 0x7000) // if fine Y < 7
                m_dataAddress += 0x1000;            // increment fine Y
            else
            {
                m_dataAddress &= ~0x7000;              // fine Y = 0
                int y = (m_dataAddress & 0x03E0) >> 5; // let y = coarse Y
                if (y == 29)
                {
                    y = 0;                   // coarse Y = 0
                    m_dataAddress ^= 0x0800; // switch vertical nametable
                }
                else if (y == 31)
                    y = 0; // coarse Y = 0, nametable not switched
                else
                    y += 1; // increment coarse Y
                m_dataAddress = (m_dataAddress & ~0x03E0) | (y << 5);
                // put coarse Y back into m_dataAddress
            }
        }
        else if (m_cycle == ScanlineVisibleDots + 2 && m_showBackground && m_showSprites)
        {
            // Copy bits related to horizontal position
            m_dataAddress &= ~0x41f;
            m_dataAddress |= m_tempAddress & 0x41f;
        }

        //                 if (m_cycle > 257 && m_cycle < 320)
        //                     m_spriteDataAddress = 0;

        // add IRQ support for MMC3
        if (m_cycle == 260 && m_showBackground && m_showSprites)
        {
            m_bus.ScanlineIRQ();
        }

        if (m_cycle >= ScanlineEndCycle)
        {
            // Find and index sprites that are on the next Scanline
            // This isn't where/when this indexing, actually copying in 2C02 is done
            // but (I think) it shouldn't hurt any games if this is done here

            m_scanlineSprites.resize(0);

            int range = 8;
            if (m_longSprites)
            {
                range = 16;
            }

            std::size_t j = 0;
            for (std::size_t i = m_spriteDataAddress / 4; i < 64; ++i)
            {
                auto diff = (m_scanline - m_spriteMemory[i * 4]);
                if (0 <= diff && diff < range)
                {
                    if (j >= 8)
                    {
                        m_spriteOverflow = true;
                        break;
                    }
                    m_scanlineSprites.push_back(i);
                    ++j;
                }
            }

            ++m_scanline;
            m_cycle = 0;
        }

        if (m_scanline >= VisibleScanlines)
            m_pipelineState = State::PostRender;

        break;
    case State::PostRender:
        if (m_cycle >= ScanlineEndCycle)
        {
            ++m_scanline;
            m_cycle = 0;
            m_pipelineState = State::VerticalBlank;

            for (std::size_t x = 0; x < m_pictureBuffer.size(); ++x)
            {
                for (std::size_t y = 0; y < m_pictureBuffer[0].size(); ++y)
                {
                    m_screen.SetPixel(x, y, m_pictureBuffer[x][y]);
                }
            }
        }

        break;
    case State::VerticalBlank:
        if (m_cycle == 1 && m_scanline == VisibleScanlines + 1)
        {
            m_vblank = true;
            if (m_generateInterrupt)
                m_vblanCallback();
        }

        if (m_cycle >= ScanlineEndCycle)
        {
            ++m_scanline;
            m_cycle = 0;
        }

        if (m_scanline >= FrameEndScanline)
        {
            m_pipelineState = State::PreRender;
            m_scanline = 0;
            m_evenFrame = !m_evenFrame;
        }

        break;
    default:
        LOG(Level::Error) << "Well, this shouldn't have happened." << std::endl;
    }

    ++m_cycle;
}

Byte PPU::readOAM(Byte addr)
{
    return m_spriteMemory[addr];
}

void PPU::writeOAM(Byte addr, Byte value)
{
    m_spriteMemory[addr] = value;
}

void PPU::DoDMA(const Byte *page_ptr)
{
    std::memcpy(m_spriteMemory.data() + m_spriteDataAddress, page_ptr, 256 - m_spriteDataAddress);
    if (m_spriteDataAddress)
        std::memcpy(m_spriteMemory.data(), page_ptr + (256 - m_spriteDataAddress), m_spriteDataAddress);
}

void PPU::Control(Byte ctrl)
{
    m_generateInterrupt = ctrl & 0x80;
    m_longSprites = ctrl & 0x20;
    m_bgPage = static_cast<CharacterPage>(
        !(!(ctrl & 0x10)));
    m_sprPage = static_cast<CharacterPage>(
        !(!(ctrl & 0x8)));

    if (ctrl & 0x4)
        m_dataAddrIncrement = 0x20;
    else
        m_dataAddrIncrement = 1;

    m_tempAddress &= ~0xc00;
    m_tempAddress |= (ctrl & 0x3) << 10;
}

void PPU::SetMask(Byte mask)
{
    m_greyscaleMode = mask & 0x1;
    m_hideEdgeBackground = !(mask & 0x2);
    m_hideEdgeSprites = !(mask & 0x4);
    m_showBackground = mask & 0x8;
    m_showSprites = mask & 0x10;
}

Byte PPU::GetStatus()
{
    Byte status = m_spriteOverflow << 5 |
                  m_sprZeroHit << 6 |
                  m_vblank << 7;
    m_vblank = false;
    m_firstWrite = true;
    return status;
}

void PPU::SetDataAddress(Byte addr)
{
    if (m_firstWrite)
    {
        m_tempAddress &= ~0xff00;
        m_tempAddress |= (addr & 0x3f) << 8;
        m_firstWrite = false;
    }
    else
    {
        m_tempAddress &= ~0xff;
        m_tempAddress |= addr;
        m_dataAddress = m_tempAddress;
        m_firstWrite = true;
    }
}

Byte PPU::GetData()
{
    auto data = m_bus.Read(m_dataAddress);
    m_dataAddress += m_dataAddrIncrement;

    // Reads are delayed by one byte
    // read when address is in this range
    if (m_dataAddress < 0x3f00)
        std::swap(data, m_dataBuffer);

    return data;
}

Byte PPU::GetOAMData()
{
    return readOAM(m_spriteDataAddress);
}

void PPU::SetData(Byte data)
{
    m_bus.Write(m_dataAddress, data);
    m_dataAddress += m_dataAddrIncrement;
}

void PPU::SetOAMAddress(Byte addr)
{
    m_spriteDataAddress = addr;
}

void PPU::SetOAMData(Byte value)
{
    writeOAM(m_spriteDataAddress++, value);
}

void PPU::SetScroll(Byte scroll)
{
    if (m_firstWrite)
    {
        m_tempAddress &= ~0x1f;
        m_tempAddress |= (scroll >> 3) & 0x1f;
        m_fineXScroll = scroll & 0x7;
        m_firstWrite = false;
    }
    else
    {
        m_tempAddress &= ~0x73e0;
        m_tempAddress |= ((scroll & 0x7) << 12) |
                         ((scroll & 0xf8) << 2);
        m_firstWrite = true;
    }
}

Byte PPU::read(Address addr)
{
    return m_bus.Read(addr);
}

PNE_NAMESPACE_END
// DONE